import os
from pathlib import Path
import argparse
from typing import Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import evaluate
import nltk
from tqdm import tqdm
import torch
import random
import json
import copy
import intel_extension_for_pytorch as ipex
import logging
import time
from datasets import load_dataset
import numpy as np
from torch.utils.data import DataLoader

parser = argparse.ArgumentParser(description="INT4 GPT-J on CNN/DailyMail")
parser.add_argument(
    "-m",
    "--model",
    type=str,
    default="EleutherAI/gpt-j-6B",
    help="the huggingface mdoel id or your local cache of model",
)
parser.add_argument(
    "--dataset-path",
    default="",
    type=str,
    help="Json file path for validation dataset, e.g., cnn_dailymail_validation.json",
)
parser.add_argument(
    "--low-precision-checkpoint",
    default="",
    type=str,
    help="Low precision checkpoint file generated by calibration with GPTQ. "
    "If provided, calibration with GPTQ is skipped.",
)
parser.add_argument(
    "--int4-model",
    default="",
    type=str,
    help="the INT4 model file path. If provided, calibration with GPTQ and quantization are skipped.",
)
parser.add_argument("--output-dir", nargs="?", default="./saved_results")
parser.add_argument(
    "--fp32",
    action="store_true",
    help="Run float32 model without quantization. Cannot use this option along with --bf16.",
)
parser.add_argument(
    "--bf16",
    action="store_true",
    help="Run bfloat16 model without quantization. Cannot use this option along with --fp32.",
)
args = parser.parse_args()

assert not (
    args.fp32 and args.bf16
), "--fp32 and --bf16 cannot be used at the same time"

random.seed(9973)
logger = logging.getLogger("INT4 GPT-J")
logger.setLevel(logging.INFO)
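# Instruction prompt templates (Alpaca-style) used to wrap each article before
# tokenization; only "prompt_input" is used here since every sample has an article.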
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


class CNNDAILYMAIL(object):
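    """Wraps the CNN/DailyMail JSON data: builds instruction prompts, tokenizes
    them, and provides items and a collate function for calibration/evaluation."""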
    def __init__(
        self,
        model_path,
        data_path,
        device="cpu",
        is_calib=False,
        num_samples=20,
        max_len=1920,
    ):
        self.model_path = model_path
        self.data_path = data_path
        self.device = device
        self.num_samples = num_samples
        self.is_calib = is_calib

        self.padding = "max_length" if self.is_calib else False
        self.max_len = 2048 if self.is_calib else max_len

        self.calib_collator = self.collate_batch
        self.pad_max = max_len
        self.load_tokenizer()
        self.load_dataset()

    def load_dataset(self):
        """Loads dataset"""
        with open(self.data_path, "r") as fid:
            list_data_dict = json.load(fid)
            self.list_data_dict = copy.deepcopy(list_data_dict)

        if self.num_samples is not None:
            self.num_samples = min(self.num_samples, len(list_data_dict))

            if self.is_calib:
                list_data_dict = list_data_dict[: self.num_samples]
            else:
                list_data_dict = random.choices(list_data_dict, k=self.num_samples)

        prompt_input, prompt_no_input = (
            PROMPT_DICT["prompt_input"],
            PROMPT_DICT["prompt_no_input"],
        )
        sources = [prompt_input.format_map(example) for example in list_data_dict]
        targets = [f"{example['output']}" for example in list_data_dict]

        self.input_ids = []
        self.input_lens = []
        self.attention_mask = []
        for i in range(len(sources)):
            tok_input = self.tokenize_function(sources[i])
            self.input_ids.append(tok_input.input_ids)
            self.attention_mask.append(tok_input.attention_mask)

        self.sources = sources
        self.targets = targets

    def load_tokenizer(self):
        """Returns the tokenizer"""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            model_max_length=2048,
            padding_side="right",
            use_fast=False,
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token

    @torch.no_grad()
    def tokenize_function(self, text):
        example = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
            padding=self.padding,
        )
        return example

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Tuple[torch.Tensor, int, torch.Tensor]:
        input_ids = self.input_ids[i]
        input_len = input_ids.shape[-1]
        attention_mask = self.attention_mask[i]
        return (input_ids, input_len, attention_mask)

    @torch.no_grad()
    def collate_batch(self, batch):
        input_ids_padded = []
        attention_mask_padded = []
        for input_ids, input_len, attention_mask in batch:
            input_ids_padded.append(input_ids)
            attention_mask_padded.append(attention_mask)

        input_ids_padded = torch.vstack(input_ids_padded)
        attention_mask_padded = torch.vstack(attention_mask_padded)
        return (input_ids_padded, attention_mask_padded)


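# ROUGE is computed with the Hugging Face `evaluate` package; NLTK's "punkt"
# tokenizer is needed for sentence splitting in postprocess_text below.
# Note: newer NLTK releases may additionally require the "punkt_tab" resource.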
nltk.download("punkt", quiet=False)
metric = evaluate.load("rouge")

parent_path = Path(__file__).parent.absolute()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)


def load_original_model(args):
    logger.info("Loading model {}...".format(args.model))
    config = AutoConfig.from_pretrained(args.model, torchscript=True)
    user_model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch.float, config=config, low_cpu_mem_usage=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    logger.info("model loaded.")
    return user_model, tokenizer


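# Unless --dataset-path is given, convert the CNN/DailyMail validation split into
# the instruction/input/output JSON format expected by CNNDAILYMAIL and save it.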
dataset_id = "cnn_dailymail"
dataset_version = "3.0.0"
dataset_split = "validation"
if args.dataset_path == "":
    instruction_template = "Summarize the following news article:"
    logger.info("Loading {} split of {} dataset...".format(dataset_split, args.model))
    dataset = load_dataset(dataset_id, name=dataset_version, split=dataset_split)
    train = dict((x["id"], x) for x in dataset)
    inputs = []
    for i in tqdm(range(len(dataset))):
        sample = dataset[i]
        x = dict()
        x["instruction"] = instruction_template
        x["input"] = sample["article"]
        x["output"] = sample["highlights"]
        inputs.append(x)

    val_data_path = os.path.join(
        args.output_dir, "cnn_dailymail_{}.json".format(dataset_split)
    )
    with open(val_data_path, "w") as write_f:
        json.dump(inputs, write_f, indent=4, ensure_ascii=False)

    logger.info("{} data saved at {}".format(dataset_split, val_data_path))
else:
    logger.info("Use the given dataset {}".format(args.dataset_path))
    val_data_path = args.dataset_path

num_beams = 4
batch_size = 1
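# Three execution paths follow:
#   1) --fp32/--bf16: run the unquantized model optimized by ipex.llm.optimize.
#   2) no --int4-model: run GPTQ calibration (first pass) and/or INT4 weight-only
#      quantization from a low-precision checkpoint (second pass).
#   3) --int4-model: load a previously traced INT4 TorchScript model.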
if args.fp32 or args.bf16:
    user_model, tokenizer = load_original_model(args)
    logger.info("Optimize model by ipex.llm.optimize")
    user_model = user_model.eval()
    user_model = user_model.to(memory_format=torch.channels_last)
    inf_dtype = torch.float if args.fp32 else torch.bfloat16
    user_model = ipex.llm.optimize(
        user_model.eval(),
        dtype=inf_dtype,
        inplace=True,
        deployment_mode=True,
    )
elif args.int4_model == "":
    if args.low_precision_checkpoint == "":
        logger.info("Do calibration with GPTQ to generate lowp-precision checkpoint.")
        logger.info("Calibration with GPTQ will take an hour or so. Please wait.")
        user_model, tokenizer = load_original_model(args)
        calib_iters = 128
        calib_dataset = CNNDAILYMAIL(
            args.model, val_data_path, is_calib=True, num_samples=calib_iters
        )
        calib_dataloader = DataLoader(
            calib_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=calib_dataset.collate_batch,
        )

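        # GPTQ calibration: compresses the weights into a low-precision checkpoint
        # saved under --output-dir. The script exits afterwards so quantization can
        # be run in a second pass with --low-precision-checkpoint.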
        compressed_model = ipex.quantization.gptq(
            model=user_model,
            dataloader=calib_dataloader,
            group_size=128,
            use_max_length=True,
            compression_dtype=torch.int32,
            compression_dim=1,
            scale_dtype=torch.float16,
            save_dir=args.output_dir,
        )

        logger.info(
            "Calibration finished. Low-precision checkpoint generated as {}.".format(
                args.output_dir
            )
        )
        # Quit here because we want to use different environment variables to run GPTQ and benchmark.
        # So, run this script twice and specify the GPTQ checkpoint file for the second run.
        quit()
    else:
        logger.info("low_precision_checkpoint is given. Calibration skipped.")
        low_precision_checkpoint_file_path = args.low_precision_checkpoint

    logger.info("Loading low_precision_checkpoint...")
    low_precision_checkpoint = torch.load(low_precision_checkpoint_file_path)
    config_dict = {
        "weight_key": "qweight",
        "scale_key": "scales",
        "zero_point_key": "qzeros",
        "bias_key": "bias",
        "g_idx_key": "g_idx",
    }
    state_dict_and_config = (low_precision_checkpoint, config_dict)
    logger.info("low_precision_checkpoint loaded.")

    user_model, tokenizer = load_original_model(args)

    logger.info("Quantize model to INT4.")
    beam_idx_tmp = torch.zeros(
        (2048, int(batch_size * num_beams)), dtype=torch.long
    ).contiguous()
    global_past_key_value = [
        (
            torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
            torch.zeros(
                [
                    1,
                    user_model.config.num_attention_heads,
                    1,
                    int(
                        user_model.config.hidden_size
                        / user_model.config.num_attention_heads
                    ),
                ]
            ).contiguous(),
            torch.zeros(
                [
                    1,
                    user_model.config.num_attention_heads,
                    1,
                    int(
                        user_model.config.hidden_size
                        / user_model.config.num_attention_heads
                    ),
                ]
            ).contiguous(),
            beam_idx_tmp,
        )
        for i in range(user_model.config.num_hidden_layers)
    ]
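    # Weight-only quantization recipe: INT4 weights with INT8 as the
    # low-precision compute mode.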
    weight_dtype = ipex.quantization.WoqWeightDtype.INT4
    lowp_mode = ipex.quantization.WoqLowpMode.INT8
    qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
        weight_dtype=weight_dtype, lowp_mode=lowp_mode
    )
    logger.info("Start quantizing model to INT4 by ipex.llm.optimize.")
    user_model = ipex.llm.optimize(
        user_model.eval(),
        dtype=torch.bfloat16,
        quantization_config=qconfig_mapping,
        inplace=True,
        low_precision_checkpoint=state_dict_and_config,
        deployment_mode=False,
    )
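    # Build example inputs (input_ids, attention_mask, past_key_values, position_ids),
    # then trace and freeze the quantized model with TorchScript so it can be saved
    # and reloaded for deployment.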
    input_ids = torch.ones(32).to(torch.long)
    attention_mask = torch.ones(len(input_ids))
    position_ids = torch.arange(len(input_ids))
    example_inputs = (
        input_ids.unsqueeze(0),
        attention_mask.unsqueeze(0),
        tuple(global_past_key_value),
        position_ids.unsqueeze(0),
    )
    with torch.no_grad(), torch.cpu.amp.autocast(enabled=True):
        self_jit = torch.jit.trace(user_model.eval(), example_inputs, strict=False)
        self_jit = torch.jit.freeze(self_jit.eval())
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        self_jit.save(args.output_dir + "/int4_model.pt")
    logger.info(
        "Quantization finished. INT4 model saved to {}.".format(
            args.output_dir + "/int4_model.pt"
        )
    )
else:
    user_model, tokenizer = load_original_model(args)
    logger.info("INT4 model is given. Quantization skipped.")
    logger.info("Loading INT4 model...")
    self_jit = torch.jit.load(args.int4_model)
    self_jit = torch.jit.freeze(self_jit.eval())
    ipex._set_optimized_model_for_generation(user_model, optimized_model=self_jit)
    logger.info("INT4 model loaded.")

logger.info("Ready to run accuracy task.")
generate_kwargs = {
    "early_stopping": True,
    "max_new_tokens": 128,
    "min_new_tokens": 30,
    "num_beams": num_beams,
}
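# Cap prompt length below GPT-J's 2048-token context; 1919 input tokens plus up to
# 128 generated tokens stays within the limit (assumed rationale for this value).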
max_len = 1919
preds = []
predictions = []
ground_truths = []


def postprocess_text(preds, targets):
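    """Strip whitespace and newline-join sentences; rougeLsum expects
    newline-separated sentences in both predictions and references."""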
    preds = [pred.strip() for pred in preds]
    targets = [target.strip() for target in targets]

    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]

    return preds, targets


# Only run 1000 samples; this saves a lot of time and closely approximates results on the whole dataset.
iters = 1000
val_dataset = CNNDAILYMAIL(
    args.model, val_data_path, is_calib=False, max_len=max_len, num_samples=iters
)
sources = val_dataset.sources
targets = val_dataset.targets
logger.info("Start running accuracy task...")
logger.info("Number of samples to run = {}".format(iters))
with torch.inference_mode(), torch.no_grad(), torch.cpu.amp.autocast(
    enabled=(False if args.fp32 else True),
    dtype=(None if args.fp32 else torch.bfloat16),
):
    for i in tqdm(range(len(sources))):
        input_ids, actual_lens, att_mask = val_dataset[i]
        input_lens = input_ids.shape[-1]
        t0 = time.time()
        out_tokens = user_model.generate(
            input_ids,
            attention_mask=att_mask,
            **generate_kwargs,
            pad_token_id=tokenizer.pad_token_id,
        )
        t1 = time.time()
        print("Inference time: {}".format(round(t1 - t0, 3)))
        print("Seq len: {}".format(input_ids.shape[-1]))
        print("Actual token len: {}".format(actual_lens))
        print("Out len: {}".format(out_tokens.shape[-1] - input_ids.shape[-1]))

        pred = out_tokens[:, input_lens:]
        pred_batch = tokenizer.batch_decode(pred, skip_special_tokens=True)

        targ_batch = targets[i : i + 1]  # batchsize=1
        preds, targs = postprocess_text(pred_batch, targ_batch)
        predictions.extend(preds)
        ground_truths.extend(targs)
        if i == iters - 1:
            break

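# use_aggregator=False returns per-sample ROUGE scores; average them and report
# the result as percentages.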
result = metric.compute(
    predictions=predictions,
    references=ground_truths,
    use_stemmer=True,
    use_aggregator=False,
)
result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
logger.info("Accuracy test results:")
logger.info(result)
